# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer

# Stub declaration of a private JoinHook subclass.
# NOTE(review): judging by the name, this is the hook used by
# ZeroRedundancyOptimizer -- confirm against the implementation module.
class _ZeROJoinHook(JoinHook):
    # Typed as Any in this stub; presumably holds the object passed to
    # __init__ as `zero` -- TODO confirm.
    zero: Any = ...
    def __init__(self, zero: Any) -> None: ...
    # Declared here with JoinHook's main_hook name; takes no arguments besides
    # self and returns None.
    def main_hook(self) -> None: ...

# Stub for a private record type: only attribute types are declared here, no
# constructor or methods.
# NOTE(review): the name suggests one instance describes parameters assigned
# from a single DDP bucket -- confirm against the implementation module.
class _DDPBucketAssignment:
    bucket_index: int
    parameters: list[torch.Tensor]
    offset: int
    device: torch.device
    # Optional: may be None (this stub does not say when it is populated).
    tensor: torch.Tensor | None

# Integer-valued status enum with three states.
#
# Members are deliberately written WITHOUT type annotations: per the typing
# specification, an annotated assignment in an enum body is reported as an
# error by type checkers, and the member's literal type is inferred from the
# assigned value instead.
#
# NOTE(review): the concrete values are meant to mirror the runtime enum in
# the implementation module (0, 1, 2 in declaration order) -- keep in sync.
class _OverlapStatus(enum.IntEnum):
    UNINITIALIZED = 0
    DDP_HAS_REBUILT_BUCKETS = 1
    INITIALIZED = 2

# Stub for a private container of bookkeeping attributes.
# NOTE(review): attribute names suggest this tracks state for overlapping the
# optimizer with DDP (buckets, broadcasts, futures) -- confirm against the
# implementation module before relying on that reading.
class _OverlapInfo:
    # The following attributes are left untyped (Any) in this stub.
    # NOTE(review): `status` is presumably an _OverlapStatus -- TODO confirm
    # and tighten the annotation.
    status: Any = ...
    params_per_bucket: Any = ...
    params_per_rank: Any = ...
    offsets: Any = ...
    broadcast_handles: Any = ...
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    # A list of sets of ints; presumably indexed by bucket -- TODO confirm.
    assigned_ranks_per_bucket: list[set[int]] = ...
    total_size: int = ...
    shard_buckets: bool = ...
    # Constructor takes no arguments besides self.
    def __init__(self) -> None: ...
    # Both methods take no arguments and return None; their effects are not
    # visible from this stub.
    def wait_for_broadcasts(self) -> None: ...
    def clear_per_iter_info(self) -> None: ...

# Stub for ZeroRedundancyOptimizer: an Optimizer subclass that also implements
# the Joinable interface (join_hook / join_device / join_process_group).
class ZeroRedundancyOptimizer(Optimizer, Joinable):
    # Attributes declared on the class; several are left as Any in this stub.
    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...
    global_rank: int = ...
    parameters_as_bucket_view: bool = ...
    optim: Any = ...
    _device_to_device_index: dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: list[list[torch.Tensor]] = ...
    _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
    # `params` and `optimizer_class` are required; the remaining named
    # parameters have defaults, and any extra keyword arguments are collected
    # in `defaults`.
    def __init__(
        self,
        params: Any,
        optimizer_class: type[Optimizer],
        process_group: Any | None = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: dict[str, Any]) -> None: ...
    # `to` is an int with a default value (presumably the target rank -- TODO
    # confirm against the implementation).
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    # step() overloads: with no closure (or closure=None) the result is None;
    # with a closure returning float the result is float.
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    # Private helpers; both `gradients` and `closure` are optional.
    def _local_step(
        self,
        gradients: list[torch.Tensor | None] | None = None,
        closure: Callable[[], float] | None = None,
        **kwargs: Any,
    ) -> float | None: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    # Joinable interface. join_hook is a regular method returning a JoinHook;
    # join_device and join_process_group are declared as properties on
    # Joinable, so both overrides must be properties here too -- otherwise
    # `opt.join_process_group` would be typed as a bound method instead of the
    # process group itself.
    def join_hook(self, **kwargs: Any) -> JoinHook: ...
    @property
    def join_device(self) -> torch.device: ...
    @property
    def join_process_group(self) -> Any: ...
